# Analyzing hierarchical LASSO output
library(tidyverse)
library(data.table)
library(hierNet)
library(colorblindr)

source("pundits_functions.R")

# Fix the RNG state so any random draws in this script are reproducible.
set.seed(seed = 1111)

# Mean squared error between observed values `obs` and predictions `pred`.
# NA inputs propagate; empty input gives NaN.
get_mse <- function(obs, pred){
  residual <- obs - pred
  mean(residual^2)
}

# Mean absolute error between observed values `obs` and predictions `pred`.
# NA inputs propagate; empty input gives NaN.
get_mae <- function(obs, pred){
  abs_error <- abs(obs - pred)
  mean(abs_error)
}

# Lookup table mapping each topic's variable-name suffix (`topic`) to its
# display label (`formal_names`); `group` numbers the 22 topics 2 through 23.
topic_refdf <- local({
  topic_labels <- c(
    china                 = "China",
    class                 = "Class",
    climate               = "Climate",
    conservative          = "Conservative",
    democrat              = "Democrat",
    far_left              = "Far Left",
    far_right             = "Far Right",
    gender                = "Gender",
    guns                  = "Guns",
    health_care_insurance = "Health Care/Insurance",
    immigration           = "Immigration",
    iran                  = "Iran",
    israel                = "Israel",
    lgbt                  = "LGBT",
    liberal               = "Liberal",
    mueller               = "Mueller",
    progressive           = "Progressive",
    race                  = "Race",
    reproductive_health   = "Reproductive Health",
    republican            = "Republican",
    taxes_spending        = "Taxes and Spending",
    trade                 = "Trade"
  )
  # unname() keeps data.frame() from turning the labels' names into row names.
  data.frame(formal_names = unname(topic_labels),
             topic = names(topic_labels),
             group = seq_along(topic_labels) + 1L)
})



# Aggregated pundit features; make_mags() takes predictor names from all but
# its first column.
pundits_agg0 <- readr::read_csv("../data/pundits_parrot_aggregated_imp0.csv")

ideal.points_d1 <- readr::read_csv("../data/ideal_points_masked.csv")

# Row positions of pundits with a missing TweetScore.
which.na.ts <- which(is.na(ideal.points_d1$tweetscore))

# Saved model fits and cross-validated predictions for the ideal-point outcome.
# NOTE(review): presumably these define `mod`, `mod.cv`, and `yhats` used
# below -- confirm against the script that wrote them.
load("../output/mod_allvars.RData")
load("../output/modcv_yhats_allvars.RData")
load("../output/modcv_yhats_allvars_heldout_yhats.RData")

# Stack the held-out predictions from every cross-validation fold, keeping each
# prediction's row index so it can be matched to its observed value. Iterates
# over the folds actually stored rather than assuming exactly ten.
yhat_df <- map_df(seq_along(yhats$folds), function(fold){
  data.frame(index = yhats$folds[[fold]],
             yhat = as.numeric(yhats$yhat[[fold]]))
})

yhat_df$Y_true <- ideal.points_d1$ideal.point[yhat_df$index]

# Index of the lambda with the smallest cross-validation error.
bestmod <- which.min(mod.cv$cv.err)

# Summarise the coefficients of one model on a fitted hierNet-style path into
# per-topic magnitudes (sums of absolute coefficient values).
#
# Each predictor name is "<prefix>_<topic>"; the topic is everything after the
# first underscore. Three magnitudes are reported per topic:
#   mag_main    - |main-effect coefficient| (bp - bn) summed over the topic's
#                 variables;
#   mag_within  - |interaction| summed over pairs of distinct variables that
#                 share the topic;
#   mag_between - |interaction| summed over pairs in which exactly one
#                 variable belongs to the topic.
# Interactions come from (th + t(th)) / 2, upper triangle only. Each unordered
# pair is therefore counted once, and diagonal entries are ignored.
#
# Args:
#   mod:       fitted path object; mod$th is indexed as a (p x p x nlam) array
#              and mod$bp / mod$bn as (p x nlam) matrices.
#   bestmod:   index of the lambda (third dimension of mod$th) to summarise.
#   var_names: the p predictor names in model column order. Defaults to all
#              but the first column of the global `pundits_agg0`.
# Returns:
#   A tibble with columns topic, mag_main, mag_within, mag_between, sorted by
#   decreasing mag_main. mag_between is NA for a topic with no between-topic
#   pairs.
make_mags <- function(mod, bestmod, var_names = names(pundits_agg0)[-1]){

  # Topic = everything after the first "_". This regex substitution replaces
  # strsplit() + x[2:length(x)], which produced "NA_<name>" for names without
  # an underscore. It also avoids sapply(), whose return type changes on
  # empty input.
  topic_of <- function(x) sub("^[^_]*_", "", x)

  th <- mod$th[, , bestmod]
  n_vars <- length(var_names)
  # Fail loudly if the names do not line up with the interaction matrix.
  stopifnot(length(th) == n_vars^2)

  # Symmetrise, then keep each off-diagonal pair exactly once.
  th <- (th + t(th)) / 2
  diag(th) <- 0
  th[lower.tri(th)] <- 0

  # Long format with one row per matrix cell. as.vector() is column-major, so
  # the row variable (Var1) cycles fastest. Built directly rather than via an
  # unqualified melt(), which with data.table attached likely resolves to
  # data.table::melt and its deprecated redirect to reshape2 for matrices.
  dfth <- data.frame(Var1 = rep(var_names, times = n_vars),
                     Var2 = rep(var_names, each = n_vars),
                     value = as.vector(th))
  dfth$topic1 <- topic_of(dfth$Var1)
  dfth$topic2 <- topic_of(dfth$Var2)

  mains <- data.frame(var = var_names,
                      coef = mod$bp[, bestmod] - mod$bn[, bestmod])
  mains$topic <- topic_of(mains$var)

  main_mags <-
    mains %>%
    group_by(topic) %>%
    summarise(mag_main = sum(abs(coef))) %>%
    arrange(desc(mag_main))

  within_mags <-
    dfth %>%
    filter(topic1 == topic2) %>%
    group_by(topic = topic1) %>%
    summarise(mag_within = sum(abs(value)))

  # A between-topic pair contributes its magnitude to both of its topics.
  off_topic <- filter(dfth, topic1 != topic2)
  between_mags <-
    bind_rows(
      off_topic %>%
        group_by(topic = topic1) %>%
        summarise(magnitude = sum(abs(value))),
      off_topic %>%
        group_by(topic = topic2) %>%
        summarise(magnitude = sum(abs(value)))
    ) %>%
    group_by(topic) %>%
    summarise(mag_between = sum(magnitude))

  main_mags %>%
    left_join(within_mags, by = "topic") %>%
    left_join(between_mags, by = "topic")
}

# Held-out predictions vs. observed follow-based ideal points, annotated with
# the cross-validated mean squared error.
predplot <- 
  yhat_df %>%
  ggplot(aes(x = Y_true, y = yhat))+
  geom_point(alpha = .5)+
  guides(alpha = "none")+
  stat_smooth(method = 'lm', formula = y ~ x)+
  annotate(geom = "text", x = 3, y = -6, 
           label = paste0("Mean Squared Error: ", 
                          round(get_mse(yhat_df$Y_true, yhat_df$yhat), 3)),
           family = 'serif')+
  labs(title = "Held-out fit",
       x = "Follow-Based Ideal Point\n(Observed)",
       y = "Follow-Based Ideal Point\n(Predicted)")+
  theme_jg()
# NOTE(review): these two outputs use an absolute ~/Desktop path, unlike the
# relative ../figures/ outputs elsewhere; writing errors if that folder is
# missing on the machine running the script.
ggsave(filename = "~/Desktop/pundits/pundits_beliefnetworks/figs/predplot_hierarchical.png",
       plot = predplot, width = 10, height = 6)
predplot$data %>%
  write.csv(file = "~/Desktop/pundits/pundits_beliefnetworks/figs/predplot_hierarchical_data.csv")

# Per-topic coefficient magnitudes for the selected lambda.
mag_total <- make_mags(mod = mod, bestmod = bestmod)

# Stacked bars of main-effect, within-topic, and between-topic magnitudes per
# concept. Concepts are ordered by total magnitude, largest at the top after
# coord_flip().
magplot <- 
  mag_total %>%
  mutate(mag_total = mag_main + mag_within + mag_between) %>%
  arrange(desc(mag_total)) %>%
  dplyr::select(-mag_total) %>%
  reshape2::melt(id.vars = c("topic")) %>%
  left_join(topic_refdf, by = "topic") %>%
  ggplot(aes(x = fct_rev(fct_inorder(formal_names)),
             y = value,
             fill= fct_rev(fct_inorder(variable))))+
  geom_col()+
  scale_fill_OkabeIto(name = "Type",
                      breaks = c("mag_main","mag_within","mag_between"),
                      labels = c("Constituent Terms",
                                 "Interactions (Within Topic)",
                                 "Interactions (Between Topics)"))+
  coord_flip()+
  labs(x = "",
       y = "Magnitude\n(Sum of Coefficient Absolute Values)",
       title= "Concept Importance")+
  theme_jg()+
  # NOTE(review): ggplot2 >= 3.5.0 deprecates a numeric legend.position in
  # favour of legend.position = "inside" plus legend.position.inside.
  theme(legend.position = c(.7, .3),
        legend.direction = "vertical")
# NOTE(review): absolute ~/Desktop output path; writing errors if that folder
# is missing.
ggsave(filename = "~/Desktop/pundits/pundits_beliefnetworks/figs/magplot_constrained.png",
       plot = magplot, width = 10, height = 6)
magplot$data %>%
  write.csv(file = "~/Desktop/pundits/pundits_beliefnetworks/figs/magplot_constrained_data.csv")

# Figure 3: held-out fit (panel A) beside concept importance (panel B), saved
# as a PNG and as a 300-dpi TIFF.
library(patchwork)
cap1 <- "Panel A: Held-out predictions from hierarchical LASSO with ten folds of cross-validation"
cap2 <- "Panel B: Sum of coefficient absolute values in best-fitting model"

pundits_fig_3 <- predplot + magplot +
  plot_layout(widths = c(1, 1))+
  plot_annotation(tag_levels = "A",
                #  title = "Figure 2. Follow-Based Ideal Points by Average Document Scores",
                  caption = paste0(cap1, "\n", cap2),
                  theme = theme_jg()+
                    theme(plot.caption = element_text(hjust = 0)))

ggsave(filename = "../figures/figure_3.png", plot = pundits_fig_3,
       width = 12, height = 6)
ragg::agg_tiff("../figures/figure_3.tiff", 
    width=12, height=6, units = "in", 
    res = 300)
print(pundits_fig_3)
dev.off()

# Print concept-wise coefficient magnitudes by coefficient type, sorted by
# main-effect magnitude and rounded to three decimals.
mag_total %>%
  arrange(desc(mag_main)) %>%
  left_join(topic_refdf, by = "topic") %>%
  dplyr::select(Concept = formal_names,
                Magnitude_Main = mag_main,
                Magnitude_Within = mag_within,
                Magnitude_Between = mag_between) %>%
  mutate(across(starts_with("Magnitude"), function(x) round(x, 3))) %>%
  knitr::kable()

# Repeat for the TweetScore outcome (Figure H.1).
# NOTE(review): presumably these loads replace `mod` and `mod.cv` with the
# TweetScore fits used below -- confirm.
load("../output/mod_allvars_tweetscore.RData")
load("../output/modcv_yhats_allvars_tweetscore.RData")

# Position of lamhat.1se within lamlist. Nearest-value matching avoids exact
# floating-point `==`. A last-bit mismatch under `==` would yield integer(0)
# and break the column selection below.
# NOTE(review): the ideal-point model above uses which.min(cv.err) instead;
# confirm the different selection rule is intended.
bestmod <- which.min(abs(mod.cv$lamlist - mod.cv$lamhat.1se))

yhats <- mod.cv$yhat[,c("index","fold",paste0("yhat.",bestmod))]
names(yhats) <- c("index","fold","yhat")
# `index` refers to positions among the non-missing TweetScores. Logical
# subsetting is used because x[-which(...)] returns an empty vector when
# nothing is missing.
yhats$Y_true <- with(ideal.points_d1, tweetscore[!is.na(tweetscore)])[yhats$index]

predplot_ts <- 
  yhats %>%
  ggplot(aes(x = Y_true, y = yhat))+
  geom_point(alpha = .5)+
  stat_smooth(method = "lm", formula = y ~ x)+
  guides(alpha = "none")+
  labs(title = "Latent Text Dimensions and TweetScores",
       subtitle = "Held-out predictions from hierarchical LASSO with ten folds of cross-validation",
       x = "TweetScore\n(Observed)",
       y = "TweetScore\n(Predicted)",
       caption = paste0("Mean Squared Error: ", 
                        round(get_mse(yhats$Y_true, yhats$yhat), 3), 
                        "\nCorrelation between predicted and observed: ",
                        round(with(yhats, cor(Y_true, yhat)), 3)))+
  theme_jg()
ggsave(filename = "../figures/predplot_ts_hierarchical.png", plot = predplot_ts,
       width = 10, height = 6)

# Per-topic coefficient magnitudes for the TweetScore model.
mag_total_ts <- make_mags(mod = mod, bestmod = bestmod)

# Print them by coefficient type, sorted by main-effect magnitude and rounded
# to three decimals.
mag_total_ts %>%
  arrange(desc(mag_main)) %>%
  left_join(topic_refdf, by = "topic") %>%
  dplyr::select(Concept = formal_names,
                Magnitude_Main = mag_main,
                Magnitude_Within = mag_within,
                Magnitude_Between = mag_between) %>%
  mutate(across(starts_with("Magnitude"), function(x) round(x, 3))) %>%
  knitr::kable()
